from typing import List, Optional

from pydantic import Field

from llmstack.common.utils.utils import get_key_or_raise
from llmstack.processors.providers.api_processor_interface import (
    ApiProcessorInterface,
    ApiProcessorSchema,
)
from llmstack.processors.providers.replicate.utils import fetch_data_from_api


class Blip2Input(ApiProcessorSchema):
    # Per-request input for the BLIP-2 processor. Documented with comments
    # rather than a class docstring, because pydantic would surface a
    # docstring as the generated schema's `description`.

    # Flag asking the model to caption the image.
    generate_caption: Optional[bool] = Field(
        default=False,
        description="Generate caption for image",
    )
    # Free-text question about the image; the empty default means captioning.
    query: Optional[str] = Field(
        default="",
        description="Question to ask about this image. Leave blank for captioning",
    )


class Blip2Configuration(ApiProcessorSchema):
    # Processor-level configuration for BLIP-2. Documented with comments
    # rather than a class docstring, because pydantic would surface a
    # docstring as the generated schema's `description`.

    # Toggle for nucleus sampling; forwarded to the model input.
    use_nucleus_sampling: Optional[bool] = Field(
        default=False,
        description="Use nucleus sampling",
    )
    # Sampling temperature; the description states the accepted range,
    # but no bounds are enforced here.
    temperature: Optional[float] = Field(
        default=0.7,
        description="Temperature for use with nucleus sampling (minimum: 0.5; maximum: 1)",
    )
    # Replicate model version identifier.
    version: str = Field(
        default="4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608",
        description="Model version",
    )
    # Whether to run the prediction synchronously.
    sync_mode: Optional[bool] = Field(
        default=False,
        description="Run in synchronous mode",
    )


class Blip2Output(ApiProcessorSchema):
    # Output schema for the BLIP-2 processor. Documented with comments
    # rather than a class docstring, because pydantic would surface a
    # docstring as the generated schema's `description`.

    # Model completions; pydantic copies the empty-list default per instance.
    generations: List[dict] = Field(
        [],
        description="The completions generated by the model.",
    )
    # Raw API response payload.
    # NOTE(review): the leading underscore may make pydantic treat this as a
    # private/ignored attribute rather than a field — confirm intended.
    _api_response: dict = Field(
        {},
        description="The raw response from the API.",
    )


class Blip2(
    ApiProcessorInterface[
        Blip2Input,
        Blip2Output,
        Blip2Configuration,
    ],
):
    """Processor that submits BLIP-2 predictions to the Replicate API."""

    @staticmethod
    def name() -> str:
        """Return the identifier this processor is registered under."""
        return "replicate/blip2"

    def process(self) -> dict:
        """Create a Replicate prediction for BLIP-2.

        The request body carries the configured model ``version`` at the top
        level. Its ``input`` merges the remaining configuration fields with the
        request input; input values win on key collisions.

        Returns:
            A dict whose ``"async"`` entry is the ``urls`` object from the
            Replicate response, and whose ``"_response"`` entry wraps the raw
            JSON response under ``"_api_response"``.

        Raises:
            Exception: if ``replicate_api_key`` is absent from the environment
                (raised by ``get_key_or_raise``) or the API response is not OK.
        """
        replicate_api_key = get_key_or_raise(
            self._env,
            "replicate_api_key",
            "No replicate_api_key found in _env",
        )

        url = "https://api.replicate.com/v1/predictions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {replicate_api_key}",
        }

        # Work on a copy of the config. `version` is sent at the top level of
        # the request, and `sync_mode` is not a model parameter, so neither
        # belongs in the model input.
        configuration = self._config.dict()
        configuration.pop("sync_mode", None)
        configuration.pop("version", None)

        # Defensive: never forward environment/secrets data to the model input.
        input_params = self._input.dict()
        input_params.pop("_env", None)

        api_params = {
            "version": self._config.version,
            "input": {**configuration, **input_params},
        }

        api_response = fetch_data_from_api(url, api_params, headers)

        if not api_response.ok:
            raise Exception(f"Error calling Replicate API: {api_response.text}")

        json_api_response = api_response.json()
        return {
            "async": json_api_response["urls"],
            "_response": {
                "_api_response": json_api_response,
            },
        }
